// Copyright 2020 The TensorStore Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef TENSORSTORE_INTERNAL_GRID_PARTITION_IMPL_H_
#define TENSORSTORE_INTERNAL_GRID_PARTITION_IMPL_H_

/// \file
/// Implementation details of grid_partition.h exposed for testing.

// IWYU pragma: private, include "tensorstore/internal/grid_partition.h"

#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/functional/function_ref.h"
#include "absl/status/status.h"
#include "tensorstore/array.h"
#include "tensorstore/index.h"
#include "tensorstore/index_interval.h"
#include "tensorstore/index_space/index_transform.h"
#include "tensorstore/index_space/internal/transform_rep.h"
#include "tensorstore/util/dimension_set.h"
#include "tensorstore/util/iterate.h"
#include "tensorstore/util/span.h"

namespace tensorstore {
namespace internal_grid_partition {

/// Precomputed data structure generated by `PrePartitionIndexTransformOverGrid`
/// that permits efficient iteration over the grid cell index vectors and
/// corresponding `cell_transform` values.
///
/// Logically, it is simply a container for a list `index_array_sets` of
/// IndexArraySet objects and a list `strided_sets` of StridedSet objects that
/// contain the precomputed data for all of the connected sets of input and grid
/// dimensions.  The order of the sets within these lists is not important for
/// correctness, although it does affect the iteration order.
class IndexTransformGridPartition {
 public:
  /// Represents a connected set containing only `single_input_dimension` edges
  /// within an IndexTransformGridPartition data structure.
  ///
  /// By definition, such connected sets must contain only a single input
  /// dimension.
  ///
  /// No precomputed data is required, as it is possible to efficiently iterate
  /// directly over the partitions.
  struct StridedSet {
    /// The grid dimension indices in this set.
    DimensionSet grid_dimensions;

    /// The single input dimension index in this set.
    int input_dimension;
  };

  /// Represents a connected set containing `array` edges within an
  /// IndexTransformGridPartition data structure.
  ///
  /// The following information is precomputed:
  ///
  /// - the list `Hs` of partial grid cell index vectors;
  ///
  /// - for each partial grid cell, index arrays that provide a one-to-one map
  ///   from a synthetic one-dimensional space to the domains of each of the
  ///   input dimensions in the connected set.
  struct IndexArraySet {
    /// The grid dimension indices in this set.
    DimensionSet grid_dimensions;

    /// The input dimension indices in this set.
    DimensionSet input_dimensions;

    // TODO(jbms): Consider using absl::InlinedVector for `grid_cell_indices`
    // and `grid_cell_partition_offsets`.

    /// Row-major array of shape `[num_grid_cells, grid_dimensions.count()]`.
    /// Logically, this is a one-dimensional array of partial grid cell index
    /// vectors, sorted lexicographically with respect to the components of the
    /// vectors (the second dimension of the array).  Each grid cell `i`
    /// corresponds to:
    ///
    ///     partitioned_input_indices[grid_cell_partition_offsets[i]:
    ///                               grid_cell_partition_offsets[i+1], :]
    ///
    /// NOTE(review): `grid_cell_partition_offsets` is documented below as
    /// having only `num_grid_cells` entries, so for the last grid cell the end
    /// bound is presumably `partitioned_input_indices.shape()[0]` rather than
    /// an element of that vector -- confirm against the implementation.
    std::vector<Index> grid_cell_indices;

    /// Array of partial input index vectors corresponding to the partial input
    /// domain of this connected set.  The vectors are partitioned by their
    /// corresponding partial grid cell index vector.  The shape is
    /// `[num_positions,input_dimensions.count()]`.
    SharedArray<Index, 2> partitioned_input_indices;

    /// Specifies the starting index into the first dimension of
    /// `partitioned_input_indices` for each grid cell in `grid_cell_indices`.
    /// Array of shape `[num_grid_cells]`.
    std::vector<Index> grid_cell_partition_offsets;

    /// Returns the index vector array of shape
    /// `{num_positions_in_partition, input_dim}` that maps the synthetic
    /// one-dimensional space to the domains of each input dimension.
    ///
    /// \param partition_i The partition index; presumably must satisfy
    ///     `0 <= partition_i && partition_i < num_partitions()` (the
    ///     definition is not in this header -- confirm).
    SharedArray<const Index, 2> partition_input_indices(
        Index partition_i) const;

    /// Returns the partial grid cell index vector corresponding to partition
    /// `i`.
    ///
    /// \param partition_i The partition index (referred to as `i` above).
    /// \returns A `span` of size `grid_dimensions.count()`.
    tensorstore::span<const Index> partition_grid_cell_indices(
        Index partition_i) const;

    /// Returns the number of partitions (partial grid cell index vectors).
    ///
    /// This is the number of entries in `grid_cell_partition_offsets`, i.e.
    /// one partition per grid cell.
    Index num_partitions() const {
      return static_cast<Index>(grid_cell_partition_offsets.size());
    }

    /// Returns the index of the partition, `partition_i`, for which
    /// `partition_grid_cell_indices(partition_i)` is equal to
    /// `grid_cell_indices`.
    ///
    /// On success returns `partition_i`, where
    /// `0 <= partition_i && partition_i < num_partitions()`.
    ///
    /// If there is no such partition, returns `-1`.
    ///
    /// NOTE(review): the parameter shares its name with the
    /// `grid_cell_indices` data member; here it denotes the vector being
    /// looked up.  Whether it is the full grid cell index vector or only the
    /// components for `grid_dimensions` is not visible from this header --
    /// confirm against the definition.
    Index FindPartition(tensorstore::span<const Index> grid_cell_indices) const;
  };

  /// Returns a read-only view of the index array connected sets
  /// (`index_array_sets_`).
  tensorstore::span<const IndexArraySet> index_array_sets() const {
    return index_array_sets_;
  }
  /// Returns a mutable reference to the underlying `index_array_sets_`
  /// container, allowing the sets to be added or modified.
  auto& index_array_sets() { return index_array_sets_; }

  /// Returns a read-only view of the strided connected sets
  /// (`strided_sets_`).
  tensorstore::span<const StridedSet> strided_sets() const {
    return strided_sets_;
  }
  /// Returns a mutable reference to the underlying `strided_sets_` container,
  /// allowing the sets to be added or modified.
  auto& strided_sets() { return strided_sets_; }

  /// Returns the "cell transform" for the grid cell given by
  /// `grid_cell_indices`.
  ///
  /// The "cell transform" has a synthetic input domain and an output range that
  /// is exactly the subset of the domain of `full_transform` that maps to
  /// output positions contained in the specified grid cell.  See
  /// `grid_partition.h` for the precise definition.
  ///
  /// \param full_transform Must match the transform supplied to
  ///     `PrePartitionIndexTransformOverGrid`.
  /// \param grid_cell_indices The grid cell for which to compute the cell
  ///     transform.
  /// \param grid_output_dimensions Must match the value supplied to
  ///     `PrePartitionIndexTransformOverGrid`.
  /// \param get_grid_cell_output_interval Computes the output interval
  ///     corresponding to a given grid cell.  Must compute a result that is
  ///     consistent with that of the `output_to_grid_cell` function supplied to
  ///     `PrePartitionIndexTransformOverGrid`.
  /// \returns The cell transform described above.
  IndexTransform<> GetCellTransform(
      IndexTransformView<> full_transform,
      tensorstore::span<const Index> grid_cell_indices,
      tensorstore::span<const DimensionIndex> grid_output_dimensions,
      absl::FunctionRef<IndexInterval(DimensionIndex grid_dim,
                                      Index grid_cell_index)>
          get_grid_cell_output_interval) const;

  /// The following members should be treated as private.  They are declared
  /// in the public section; prefer the accessors above.

  /// Precomputed data for the strided connected sets.
  absl::InlinedVector<StridedSet, internal::kNumInlinedDims> strided_sets_;

  /// Precomputed data for the index array connected sets.
  std::vector<IndexArraySet> index_array_sets_;
};

/// Allocates the `cell_transform` and initializes the portions that are the
/// same for all grid cells.
///
/// The portions that differ between grid cells are left for the caller to
/// fill in; for index array sets, see
/// `UpdateCellTransformForIndexArraySetPartition`.
///
/// \param info The preprocessed partitioning data.
/// \param full_transform The full transform.
/// \returns A non-null pointer to a partially-initialized transform from the
///     synthetic "cell" index space, of rank `cell_input_rank`, to the "full"
///     index space, of rank `full_input_rank`.  Here the "full" index space is
///     the input domain of `full_transform`.  NOTE(review): `cell_input_rank`
///     is not defined in this header; presumably it is determined by the
///     connected sets in `info` -- confirm against the definition.
internal_index_space::TransformRep::Ptr<> InitializeCellTransform(
    const IndexTransformGridPartition& info,
    IndexTransformView<> full_transform);

/// Updates the output index maps and input domain in `cell_transform` to
/// correspond to `partition_i` of `index_array_set`.
///
/// The transform is modified in place through the raw pointer; nothing is
/// returned.
///
/// \param index_array_set The index array set.
/// \param set_i The index of `index_array_set`, equal to the corresponding
///     input dimension of `cell_transform`.
/// \param partition_i The partition index.  Presumably must satisfy
///     `0 <= partition_i && partition_i < index_array_set.num_partitions()`
///     (the definition is not in this header -- confirm).
/// \param cell_transform Non-null pointer to cell transform to update.
void UpdateCellTransformForIndexArraySetPartition(
    const IndexTransformGridPartition::IndexArraySet& index_array_set,
    DimensionIndex set_i, Index partition_i,
    internal_index_space::TransformRep* cell_transform);

/// Precomputes a data structure for partitioning an index transform by a
/// multi-dimensional grid.
///
/// The grid is a multi-dimensional grid where the mapping function
/// `output_to_grid_cell` maps from the {dimension, index} pair to a cell
/// index for that dimension.  The mapping from grid dimensions to output
/// dimensions of the index transform is specified by the
/// `grid_output_dimensions` array.
///
/// \param index_transform The index transform to partition.
/// \param grid_output_dimensions The sequence of output dimensions of
///     `index_transform` corresponding to the grid over which the index
///     transform is to be partitioned.
/// \param output_to_grid_cell Function to translate from output index to
///     a grid cell.  It is called as
///     `output_to_grid_cell(grid_dim, output_index, cell_bounds)` and returns
///     the grid cell index for dimension `grid_dim`.  NOTE(review): the
///     `cell_bounds` argument is presumably an output that receives the
///     output index interval covered by the returned cell, and may be null
///     when the bounds are not needed -- confirm against the callers.
/// \param[out] grid_partition Will be initialized with the partitioning
///     information.
/// \returns An `absl::Status` indicating success, or one of the errors listed
///     below.
/// \error `absl::StatusCode::kInvalidArgument` if any input dimension of
///     `index_transform` has an unbounded domain.
/// \error `absl::StatusCode::kInvalidArgument` if integer overflow occurs.
/// \error `absl::StatusCode::kOutOfRange` if an index array contains an
///     out-of-bounds index.
absl::Status PrePartitionIndexTransformOverGrid(
    IndexTransformView<> index_transform,
    tensorstore::span<const DimensionIndex> grid_output_dimensions,
    absl::FunctionRef<Index(DimensionIndex grid_dim, Index output_index,
                            IndexInterval* cell_bounds)>
        output_to_grid_cell,
    IndexTransformGridPartition& grid_partition);

}  // namespace internal_grid_partition
}  // namespace tensorstore

#endif  // TENSORSTORE_INTERNAL_GRID_PARTITION_IMPL_H_
